#!/usr/bin/python

# make the clustering class-specific

import sys,os,re,glob,math,glob,signal,traceback,sqlite3
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from scipy.ndimage import interpolation,measurements
from pylab import *
from optparse import OptionParser
from multiprocessing import Pool
import ocrolib
from ocrolib import number_of_processors,fstutils,die,docproc
from scipy import stats
from ocrolib import dbhelper,ocroio,ocrofst

# Terminate with exit status 1 on SIGINT (Ctrl-C) instead of Python's default
# KeyboardInterrupt handling.
signal.signal(signal.SIGINT,lambda *args:sys.exit(1))

# Command-line interface.  The usage text describes the three ways of choosing
# what to align with: -l (one language model), -g (per-line ground truth
# files) and -p (page-level ground truth).
parser = OptionParser("""
usage: 
%prog [-s gt] [-l langmod] [options] file.fst ...

    Align each file.fst with language model and output the result into
    file.cseg.png, file.costs, and file.txt.  With -s gt, writes the
    results into file.cseg.gt.png, file.gt.costs, and file.gt.txt 
    instead.

%prog [-s gt] [-g extension] [options] *.fst

    Align with groundtruth files.  For each x.fst, looks for the ground truth
    in x.extension

%prog [-s gt] [-p] [options] *.gt.txt

    Align with page-level ground truth.  For each page.gt.txt, looks in
    page/??????.png for the image data.

Aligns recognition lattices (stored in .fst files) with language models and/or
ground truth.  If -g extension is given, each foo.fst file is aligned with a
corresponding foo.extension file.  

Language models and ground truth files may be given either as FST files or 
text files.  If they are given as text files, an FST is constructed by using
the fstutils.add_line_to_fst function.

""")


parser.add_option("-l","--langmod",help="language model",default=None)
# -x names the sqlite database file that extracted characters are written to;
# the script refuses to run if that file already exists.
parser.add_option("-x","--extract",help="extract characters",default=None)
parser.add_option("-X","--noextract",help="don't actually write files",action="store_true")
parser.add_option("-L","--ligatures",help="output ligatures in aligned text",action="store_true")
parser.add_option("-B","--beam",help="size of beam",type=int,default=1000)
parser.add_option("-N","--noligatures",help="don't expand ligature notation in transcriptions",action="store_true")
parser.add_option("-g","--gt",help="extension for ground truth",default=None)
parser.add_option("-p","--pagegt",help="arguments are page ground truth",action="store_true")
parser.add_option("-s","--suffix",help="output suffix for writing result",default=None)
parser.add_option("-O","--overwrite",help="overwrite outputs",action="store_true")
# -P/-M/-A: a line is skipped when the cost at the -P percentile exceeds -M
# or the mean cost exceeds -A.
parser.add_option("-P","--perc",help="percentile for reporting statistics",type=float,default=90.0)
parser.add_option("-M","--maxperc",help="maximum cost at percentile",type=float,default=2.0)
parser.add_option("-A","--maxavg",help="maximum average cost",type=float,default=3.0)
# The -a value is handed to ocrolib.load_component; the leading "=" presumably
# marks a Python expression to evaluate -- confirm in ocrolib.
parser.add_option("-a","--aligner",help="choose the aligner to use",default="=ocrolib.fstutils.DefaultAligner()")
parser.add_option("-Y","--debug_cls",help="select classes for debugging",default="")
# Values below 2 run all jobs sequentially in this process.
parser.add_option("-Q","--parallel",type=int,default=number_of_processors(),help="number of parallel processes to use")
parser.add_option("-c","--cont",help="continue on error",action="store_true")
parser.add_option("-E","--showerrs",help="show errors",action="store_true")
parser.add_option("-D","--Display",help="display",action="store_true")
(options,args) = parser.parse_args()

if len(args)==0:
    parser.print_help()
    sys.exit(0)

if options.showerrs:
    options.parallel=0

if options.extract is not None and os.path.exists(options.extract):
    print "output file already exists:",options.extract
    sys.exit(1)

# Database handling is a bit tricky because of concurrency.
# First create the database and update the columns if necessary.
# Then close the db descriptor.  Reopen it in each multiprocessing
# subtask if necessary.

db = None

if options.extract:
    db = sqlite3.connect(options.extract,timeout=600.0)
    dbhelper.charcolumns(db,"chars")
    db.commit()
    db.close()
    db = None

def dbopen():
    """Open the character database in the current process if needed.

    Does nothing unless -x/--extract was given and the module-level `db`
    is still None.  Otherwise connects to options.extract (600 s busy
    timeout), uses dbhelper.DbRow as the row factory and
    sqlite3.OptimizedUnicode as the text factory, and turns off
    synchronous writes with "pragma synchronous=0".
    """
    global db
    if db is not None or not options.extract:
        return
    db = conn = sqlite3.connect(options.extract,timeout=600.0)
    conn.row_factory = dbhelper.DbRow
    conn.text_factory = sqlite3.OptimizedUnicode
    conn.execute("pragma synchronous=0")
    conn.commit()

# The aligner is chosen with -a/--aligner and instantiated by
# ocrolib.load_component (default: "=ocrolib.fstutils.DefaultAligner()").
aligner = ocrolib.load_component(options.aligner)

# One-entry cache used by align1: lfile is the filename of the most recently
# loaded language model, lfst the FST built from it.
lfile = lfst = None

def safe_align1(t):
    """Run align1(t), printing any exception's traceback instead of raising it.

    This is the callback given to multiprocessing.Pool.map.  Pool.map
    re-raises the first exception raised by a callback, so swallowing
    errors here keeps one bad line from aborting the whole run.
    """
    try:
        align1(t)
    except Exception:
        # Exception rather than a bare except: SystemExit and
        # KeyboardInterrupt should still stop the worker.
        traceback.print_exc()

def align1(t):
    dbopen()
    global lfile,lfst
    (fname,lmodel) = t
    ocrolib.fcleanup(fname,options.suffix,["txt","costs"])

    try:
        fst = ocrofst.OcroFST()
        fst.load(ocrolib.ffind(fname,"fst"))
        rseg = ocroio.read_line_segmentation(ocrolib.ffind(fname,"rseg"))

        if lmodel!=lfile:
            lfile = lmodel
            nolig = not options.noligatures
            if lmodel.endswith(".fst"):
                lfst = ocrofst.OcroFST().load(lmodel)
            else:
                lfst = aligner.ocroFstForFile(lmodel)

        if lfst is None:
            sys.exit(1)
        
        r = ocrolib.compute_alignment(fst,rseg,lfst)

        cseg = r.cseg
        assert amin(cseg)==0,"amin(cseg)!=0 (%d,%d)"%(amin(cseg),amax(cseg))
        costs = r.costs

        # first work on the list output; here, one list element should
        # correspond to each cost
        result = r.output_l
        assert len(result)==len(costs),\
            "output length %d differs from cost length %d"%(len(result),len(costs))
        assert amax(r.cseg)<len(result),\
            "amax(r.cseg) %d not consistent with output length %d"%(amax(r.cseg),len(r.output_l))

        # if there are spaces at the end, trim them (since they will never have a corresponding cseg)
        while len(result)>0 and result[-1]==" ":
            result = result[:-1]
            costs = costs[:-1]
        if amax(r.cseg)+1!=len(result):
            # Note: we can't distinguish deletions at the end from other misalignment problems;
            # these fields are probably OK, but for now we just skip them.  Later there will
            # be an option to output them anyway.
            print "%s: segmentation and transcript don't align; probably deletion at end"%fname
            if options.showerrs:
                ion()
                axis = subplot(111)
                axis.imshow(r.cseg,cmap=cm.gist_stern) # cmap=cm.flag
                objects = measurements.find_objects(r.cseg)
                xs = [0.5*(p[1].start+p[1].stop) if p is not None else None for p in objects]
                xs = [0]+xs
                h,w = r.cseg.shape
                for i in range(min(len(result),len(xs))):
                    # print "[%s]"%result
                    if xs[i] is None: continue
                    axis.text(xs[i],h,result[i],color="red",size=10)
                for i in range(min(len(result),len(xs)),len(result)):
                    print "extra",i,result[i]
                show()
                raw_input()
            return

        if options.ligatures:
            result = fstutils.implode_transcription(result_l),
        else:
            result = "".join(result)

        perc = stats.scoreatpercentile(costs,options.perc)
        avg = mean(costs)
        skip = (perc > options.maxperc or avg>options.maxavg)

        if len(r.output)==0:
            print "* %s: *****"%(fname,)
            return
        else:
            print "%-1s %s: %5.2f %5.2f: %s"%("*" if skip else " ",fname,perc,avg,result)
            if skip: return

        if not options.noextract:
            cseg_file = ocrolib.fvariant(fname,"cseg",options.suffix)
            if not options.overwrite:
                if os.path.exists(cseg_file): die("%s: already exists",cseg_file)
            ocrolib.write_line_segmentation(cseg_file,cseg)
            ocrolib.write_text(ocrolib.fvariant(fname,"txt",options.suffix),result)
            with ocrolib.fopen(fname,"costs",options.suffix,mode="w") as stream:
                for i in range(len(costs)):
                    stream.write("%d %g\n"%(i,costs[i]))

        if db is not None:
            # for i in range(len(r.segs)): print "%6d %-8s %s"%(i,r.segs[i],r.output_l[i])
            line = ocrolib.read_image_gray(ocrolib.ffind(fname,"png"))
            line = amax(line)-line
            lgeo = docproc.seg_geometry(rseg)
            grouper = ocrolib.Grouper()
            grouper.setSegmentation(rseg)
            for i in range(grouper.length()):
                raw,mask = grouper.extractWithMask(line,i,dtype='B')
                if options.Display: 
                    ion(); clf(); gray(); imshow(raw); ginput(1,0.01)
                start = grouper.start(i)
                end = grouper.end(i)
                bbox = grouper.boundingBox(i)
                y0,x0,y1,x1 = bbox
                rel = docproc.rel_char_geom((y0,y1,x0,x1),lgeo)
                ry,rw,rh = rel
                assert rw>0 and rh>0
                if (start,end) in r.segs:
                    index = r.segs.index((start,end))
                    cls = r.output_l[index]
                    cost = r.costs[index]
                else:
                    cls = "~"
                    cost = 0.0
                if cls in options.debug_cls:
                    print "debug",start,end,"cls",cls,"cost",cost,\
                        "y %.2f w %.2f h %.2f"%(rel[0],rel[1],rel[2])
                dbhelper.dbinsert(db,"chars",
                               image=dbhelper.image2blob(raw),
                               cost=float(cost),
                               cls=cls,
                               count=1,
                               file= fname,
                               lgeo="%g %g %g"%lgeo,
                               rel="%g %g %g"%rel,
                               bbox="%d %d %d %d"%bbox)
    except IOError,e:
        print "# not found:",e
    except:
        if not options.cont: raise
        traceback.print_exc()
    if db is not None:
        db.commit()


jobs = []

if options.pagegt is not None:
    if len(args)==1 and os.path.isdir(args[0]):
        args = glob.glob(args[0]+"/*.gt.txt")
    for arg in args:
        if not os.path.exists(arg):
            print "# %s: not found"%arg
            continue
        base,_ = ocrolib.allsplitext(arg)
        if not os.path.exists(base) or not os.path.isdir(base):
            print "?"+base,
            sys.stdout.flush()
            continue
        lines = glob.glob(base+"/??????.png")
        for line in lines:
            jobs.append((line,arg))
    print
elif options.langmod is not None:
    for arg in args:
        jobs.append((arg,options.langmod))
elif options.gt is not None:
    if len(args)==1 and (os.path.isdir(args[0]) or os.path.islink(args[0])):
        args = glob.glob(args[0]+"/????/??????.png")
    for arg in args:
        path,ext = ocrolib.allsplitext(arg)
        p = path+options.gt
        if not os.path.exists(arg):
            print arg,"not found"
            continue
        if not os.path.exists(p): 
            print p,"not found"
            continue
        jobs.append((arg,p))
else:
    raise Exception("you need to specify what kind of groundtruth you want to align with (-p, -l, -g)")

if options.parallel<2:
    for arg in jobs: align1(arg)
else:
    pool = Pool(processes=options.parallel)
    result = pool.map(safe_align1,jobs)
